# main.py
import argparse
import os
import csv
import numpy as np
from utils import set_seed
from data_loader import load_data
from embedding import compute_embeddings
from analysis_delta import analyze_delta_hidden_mi
from analysis import analyze_batch, visualize_information_flow
from transformers import AutoTokenizer
import torch
import gc

def _write_information_flow_csv(aggregated_analysis, analysis_csv_path):
    """Write the information-flow metrics to a CSV with one global row plus one row per layer.

    Args:
        aggregated_analysis: Mapping holding the global metric values under
            their column names and, optionally, a "layer_results" mapping of
            layer name -> metric mapping. May be empty.
        analysis_csv_path: Destination CSV path; the file is overwritten.
    """
    # Column layout of the CSV; the order here is the order in the file.
    fieldnames = [
        "layer",
        "H_X", "H_Y", "H_joint_XY", "I_XY",
        "H_X_prev", "H_joint_XprevY", "I_XprevY",
        "H_Z_final", "H_joint_ZY", "I_ZY", "I_ZY/H_Z"
    ]

    with open(analysis_csv_path, mode='w', newline='') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()

        # Global row: only the first seven metric columns are filled; the
        # remaining columns are left to DictWriter's default (empty) value.
        global_row = {key: aggregated_analysis.get(key, "") for key in fieldnames[1:8]}
        global_row["layer"] = "global"
        writer.writerow(global_row)

        # Per-layer rows, ordered by the integer after the first underscore
        # of the layer name.
        layer_results = aggregated_analysis.get("layer_results", {})
        for layer_name in sorted(layer_results, key=lambda name: int(name.split("_")[1])):
            row = dict.fromkeys(fieldnames, "")
            row["layer"] = layer_name
            row.update(layer_results[layer_name])

            # Derived column I_ZY / H_Z_final; left blank when either input is
            # missing, non-numeric (including None) or the denominator is zero.
            try:
                izy = float(row["I_ZY"])
                hzf = float(row["H_Z_final"])
            except (TypeError, ValueError, KeyError):
                row["I_ZY/H_Z"] = ""
            else:
                row["I_ZY/H_Z"] = izy / hzf if hzf != 0 else ""

            writer.writerow(row)


def _save_delta_mi_csvs(delta_hidden_mi_results, output_csv_dir):
    """Save the delta-hidden-state MI table, and its description table if present, as CSV.

    Args:
        delta_hidden_mi_results: Result mapping containing a "df" entry and
            optionally a "description_df" entry; both must expose `to_csv`.
        output_csv_dir: Base directory under which the CSV files are written.
    """
    # NOTE(review): the directory and file names below are constant
    # placeholders with no epoch identifier, so every call writes to the same
    # paths - confirm the intended naming.
    delta_mi_dir = os.path.join(output_csv_dir, "path/to/your/file")
    os.makedirs(delta_mi_dir, exist_ok=True)

    # Numerical mutual-information values.
    delta_mi_csv_path = os.path.join(delta_mi_dir, f"path/to/your/file")
    delta_hidden_mi_results["df"].to_csv(delta_mi_csv_path)
    print(f"Saved delta hidden states MI matrix to {delta_mi_csv_path}")

    # Optional companion table describing what each cell represents,
    # e.g. "I(h_j-h_i; Y)".
    if "description_df" in delta_hidden_mi_results and delta_hidden_mi_results["description_df"] is not None:
        desc_csv_path = os.path.join(delta_mi_dir, f"path/to/your/file")
        delta_hidden_mi_results["description_df"].to_csv(desc_csv_path)
        print(f"Saved delta hidden states MI descriptions to {desc_csv_path}")


def run_epoch_analysis(epoch_checkpoint, args, device, tokenizer, vmin=None, vmax=None):
    """Run the information-flow and delta-hidden-state MI analysis for one checkpoint.

    Loads the data, drops samples whose combined text does not tokenize to
    more tokens than its prompt, computes embeddings for the remaining
    samples and runs the analyses. Unless `args.collect_only` is set, the
    results are also written to CSV files under `args.output_csv`.

    Args:
        epoch_checkpoint: Path of the model checkpoint to analyze.
        args: Parsed arguments; this function reads `input_file`, `sigma`,
            `fig_dir`, `output_csv`, `collect_only` and (conditionally)
            `delta_h_dir`.
        device: Device passed through to `compute_embeddings`.
        tokenizer: Tokenizer used for length filtering and embeddings.
        vmin: Optional lower colorbar bound forwarded to the delta MI analysis.
        vmax: Optional upper colorbar bound forwarded to the delta MI analysis.

    Returns:
        The "mi_matrix" entry of the delta MI results, or None when there
        were no valid samples or the analysis returned no such entry.
    """
    print(f"\nProcessing epoch checkpoint: {epoch_checkpoint}")

    # Load data
    prompts, combined_texts = load_data(args.input_file, sample_count=100, dataset_split="test", dataset_type="clutrr")

    # Keep only samples whose combined text is strictly longer (in tokens)
    # than its prompt.
    valid_context_lengths = []
    valid_prompts = []
    valid_combined_texts = []
    for i, (prompt, combined_text) in enumerate(zip(prompts, combined_texts)):
        context_length = len(tokenizer.encode(prompt, add_special_tokens=True))
        if context_length < 1:
            print(f"Sample {i} prompt too short, skipping.")
            continue
        combined_tokens = tokenizer.encode(combined_text, add_special_tokens=True)
        if len(combined_tokens) <= context_length:
            print(f"Sample {i} combined text too short, skipping.")
            continue

        valid_context_lengths.append(context_length)
        valid_prompts.append(prompt)
        valid_combined_texts.append(combined_text)

    aggregated_analysis = {}
    delta_hidden_mi_results = {}
    mi_matrix = None
    # Bound up front so the cleanup at the end of the function is valid even
    # when no sample survives the filter (previously an UnboundLocalError).
    hidden_states_y = hidden_states_z = input_ids = lm_head_weight = None

    if valid_combined_texts:
        # Hidden states for the combined texts (y) and for the prompts alone (z)
        hidden_states_y, input_ids, lm_head_weight = compute_embeddings(
            valid_combined_texts, device, tokenizer, epoch_checkpoint
        )
        hidden_states_z, _, _ = compute_embeddings(valid_prompts, device, tokenizer, epoch_checkpoint)

        # Standard information flow analysis (skipped while only collecting
        # MI matrices for colorbar normalization)
        if not args.collect_only:
            aggregated_analysis = analyze_batch(
                hidden_states_z,
                hidden_states_y,
                input_ids,
                valid_context_lengths,
                lm_head_weight,
                sigma=args.sigma,
                save_dir=args.fig_dir,
            )

        # NOTE(review): the attribute name checked here is a placeholder
        # string, so this test is always False and the fallback under fig_dir
        # is always used - confirm the intended attribute name.
        delta_h_save_dir = args.delta_h_dir if hasattr(args, "path/to/your/file") else os.path.join(args.fig_dir, "path/to/your/file")

        # If we're just collecting data for colorbar normalization, don't save visualizations
        save_dir = None if args.collect_only else delta_h_save_dir

        delta_hidden_mi_results = analyze_delta_hidden_mi(
            hidden_states_z=hidden_states_z,
            hidden_states_y=hidden_states_y,
            input_ids=input_ids,
            context_lengths=valid_context_lengths,
            lm_head_weight=lm_head_weight,
            sigma=args.sigma,
            save_dir=save_dir,
            epoch_checkpoint=os.path.basename(epoch_checkpoint),
            vmin=vmin,
            vmax=vmax
        )

        # Always pick up the MI matrix, even in collect_only mode
        if delta_hidden_mi_results and "mi_matrix" in delta_hidden_mi_results:
            mi_matrix = delta_hidden_mi_results["mi_matrix"]
            if args.collect_only:
                print(f"Successfully collected MI matrix with shape {mi_matrix.shape}")
                # Print min and max values of this matrix for debugging
                print(f"Matrix min: {np.min(mi_matrix)}, max: {np.max(mi_matrix)}")
    else:
        print("No valid samples for analysis.")

    # Only save CSVs if not in collect_only mode
    if not args.collect_only:
        # NOTE(review): constant placeholder file name without an epoch
        # identifier - each epoch overwrites the same file; confirm naming.
        analysis_csv_path = os.path.join(args.output_csv, f"path/to/your/file")
        os.makedirs(os.path.dirname(analysis_csv_path), exist_ok=True)
        _write_information_flow_csv(aggregated_analysis, analysis_csv_path)
        print(f"Saved formatted information flow metrics to {analysis_csv_path}")

        # Save delta hidden states MI matrix to separate CSV
        if delta_hidden_mi_results and "df" in delta_hidden_mi_results:
            _save_delta_mi_csvs(delta_hidden_mi_results, args.output_csv)

    # === FREE GPU MEMORY AFTER EACH EPOCH ===
    # Drop the references to the large objects before collecting garbage and
    # emptying the CUDA cache; mi_matrix is kept because it is returned.
    del hidden_states_y, hidden_states_z, input_ids, lm_head_weight
    del delta_hidden_mi_results, aggregated_analysis
    gc.collect()
    torch.cuda.empty_cache()

    return mi_matrix

def main():
    """Parse command-line arguments and analyze every selected epoch checkpoint.

    When `--consistent_colorbar` is given, the checkpoints are processed
    twice: a first pass with `args.collect_only = True` gathers each epoch's
    MI matrix to determine a global minimum and maximum, and the second pass
    runs the full analysis with those bounds passed as `vmin`/`vmax`.
    Without the flag only the full pass runs, with `vmin`/`vmax` left as None.
    """
    parser = argparse.ArgumentParser(description="Run analysis over multiple epochs.")
    parser.add_argument("--input_file", type=str, default="",
                        help="File with prompts; if empty, use CLUTRR dataset.")
    parser.add_argument("--checkpoint_base", type=str, required=True,
                        help="Base directory for model checkpoints; assumes subdirectories epoch_1, ..., epoch_5")
    parser.add_argument("--sigma", type=float, default=1.0, help="Sigma for Gaussian kernel.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--output_csv", type=str, default="path/to/your/file",
                        help="Directory to save information flow CSV results.")
    parser.add_argument("--fig_dir", type=str, default="path/to/your/file",
                        help="Directory to save figures.")
    parser.add_argument("--delta_h_dir", type=str, default="path/to/your/file",
                        help="Directory to save delta hidden states analysis results.")
    parser.add_argument("--pred_csv", type=str, default="path/to/your/file",
                        help="Directory to save layer predictions CSV.")
    parser.add_argument("--epoch_start", type=int, default=0,
                        help="Starting epoch for analysis.")
    parser.add_argument("--epoch_end", type=int, default=0,
                        help="Ending epoch for analysis.")
    # Default is 1 rather than 0: a zero step makes range() raise ValueError,
    # so the previous default crashed every run that did not override it.
    parser.add_argument("--epoch_step", type=int, default=1,
                        help="Step size between epochs for analysis (must be non-zero).")
    parser.add_argument("--consistent_colorbar", action="store_true",
                        help="Use consistent colorbar scale across all epoch visualizations.")
    args = parser.parse_args()

    # Reject an explicit zero step with a usage error instead of a traceback
    # from range() further down.
    if args.epoch_step == 0:
        parser.error("--epoch_step must be non-zero")

    # Add a temporary flag for collecting data without saving visualizations
    args.collect_only = False

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize tokenizer using the checkpoint of the first epoch.
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.checkpoint_base, f"path/to/your/file"))
    if tokenizer.pad_token is None:
        # Fall back to the EOS token when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    # Print configuration information
    print(f"Configuration:")
    print(f"  Input file: {args.input_file}")
    print(f"  Checkpoint base: {args.checkpoint_base}")
    print(f"  Sigma: {args.sigma}")
    print(f"  Output CSV: {args.output_csv}")
    print(f"  Figure directory: {args.fig_dir}")
    print(f"  Delta hidden states directory: {args.delta_h_dir}")
    print(f"  Analyzing epochs: {args.epoch_start} to {args.epoch_end} with step {args.epoch_step}")
    print(f"  Using consistent colorbar: {args.consistent_colorbar}")

    # Create every output directory up front (delta hidden states results,
    # their CSV subdirectory, and the main CSV / figure / prediction dirs).
    delta_h_csv_dir = os.path.join(args.output_csv, "path/to/your/file")
    for directory in (args.delta_h_dir, delta_h_csv_dir, args.output_csv, args.fig_dir, args.pred_csv):
        os.makedirs(directory, exist_ok=True)

    # Define epochs to analyze (epoch_end is inclusive for a positive step)
    epochs = list(range(args.epoch_start, args.epoch_end + 1, args.epoch_step))

    # If using consistent colorbar, first gather all MI matrices to find global min and max
    global_vmin, global_vmax = None, None

    if args.consistent_colorbar:
        print("First pass: collecting MI matrices to determine global min/max for consistent colorbar...")
        all_mi_matrices = []

        # Set collect_only flag to True for first pass
        args.collect_only = True

        for epoch in epochs:
            # NOTE(review): the checkpoint sub-path is a constant placeholder
            # that does not use `epoch`, so every iteration resolves to the
            # same path - confirm the intended per-epoch directory name.
            epoch_checkpoint = os.path.join(args.checkpoint_base, f"path/to/your/file")
            print(f"Collecting data from {epoch_checkpoint}")

            # Run analysis without saving visualizations; vmin/vmax stay
            # unset because the global bounds are not known yet.
            mi_matrix = run_epoch_analysis(
                epoch_checkpoint=epoch_checkpoint,
                args=args,
                device=device,
                tokenizer=tokenizer,
                vmin=None,
                vmax=None
            )

            # Keep only matrices that exist and are non-empty
            if mi_matrix is not None and mi_matrix.size > 0:
                all_mi_matrices.append(mi_matrix)
                print(f"Successfully added MI matrix to collection (total: {len(all_mi_matrices)})")

        # Compute global min and max if we have matrices
        if all_mi_matrices:
            global_vmin = float(min(np.min(matrix) for matrix in all_mi_matrices))
            global_vmax = float(max(np.max(matrix) for matrix in all_mi_matrices))
            print(f"Global min: {global_vmin}, Global max: {global_vmax}")
        else:
            print("No valid MI matrices found in first pass. Using automatic scaling.")

        # Reset collect_only flag for second pass
        args.collect_only = False

    # Now run the actual analysis with proper visualizations
    print("Running full analysis with visualizations...")
    for epoch in epochs:
        epoch_checkpoint = os.path.join(args.checkpoint_base, f"path/to/your/file")
        run_epoch_analysis(
            epoch_checkpoint=epoch_checkpoint,
            args=args,
            device=device,
            tokenizer=tokenizer,
            vmin=global_vmin,
            vmax=global_vmax
        )
        
# Script entry point: run the analysis only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()